import numpy as np
from emodel_generalisation.mcmc import plot_corner
from bluepyparallel import init_parallel_factory
import yaml
from functools import partial

from pathlib import Path
import matplotlib.pyplot as plt
import pickle

# from emodel_generalisation.utils import plot_traces
from emodel_generalisation.model.modifiers import synth_soma, synth_axon
from emodel_generalisation.mcmc import load_chains
from emodel_generalisation.mcmc import save_selected_emodels
import logging
from itertools import cycle
import json
from matplotlib.backends.backend_pdf import PdfPages
import pandas as pd
from emodel_generalisation.model.access_point import AccessPoint
from emodel_generalisation.model.evaluation import feature_evaluation
from emodel_generalisation.utils import get_combo_hash
from emodel_generalisation.utils import get_feature_df
from datareuse import Reuse
import extra_features

# Burst-number feature column shown in both the scatter and the corner plot.
BURST_NUMBER_FEATURE = "Step_ReboundBurst_burst.soma.v.all_burst_number"

# Parameter names kept in the corner plot; every other parameter column is dropped.
CORNER_PARAMETERS = (
    "g_pas.all",
    "gcabar_it2.basal",
    "shift_it2.somadend",
    # "constant.distribution_increase",
    "gbar_ican.basal",
    "gnabar_hh2.somatic",
)

# Top-level column groups whose entries are filtered against CORNER_PARAMETERS.
# NOTE(review): both raw and normalized parameter groups are filtered; narrow this
# to ("normalized_parameters",) if only the normalized group was meant to be reduced.
PARAMETER_GROUPS = ("parameters", "normalized_parameters")


def columns_to_drop(columns, keep=CORNER_PARAMETERS, groups=PARAMETER_GROUPS):
    """Return the column keys to drop before making the corner plot.

    A key ``(group, name, ...)`` is dropped when ``group`` is exactly one of
    ``groups`` (whole-string match, not substring) and ``name`` is not in
    ``keep``. All other keys are kept.

    Args:
        columns: iterable of tuple keys, e.g. a pandas ``MultiIndex`` of columns.
        keep: parameter names to retain in the filtered groups.
        groups: top-level column groups that are filtered.

    Returns:
        list of the keys to drop, in their original order.
    """
    keep = set(keep)
    groups = set(groups)
    return [c for c in columns if c[0] in groups and c[1] not in keep]


def main(max_cost=2.0, apply_runaway_mask=False):
    """Plot an it2-vs-burst-number scatter and a burst-number corner plot.

    Loads chains from ``../../mcmc_run/run_df.csv`` with ``load_chains`` and
    keeps samples with ``cost < max_cost``. Writes ``it2_scatter.pdf`` and
    ``burst_number_corner.pdf`` to the current working directory.

    Args:
        max_cost: samples with cost >= this value are discarded.
        apply_runaway_mask: if True, additionally keep only samples with
            ``burst_runaway < 0.05`` and ``time_to_last_spike > 20000``. The
            feature names suggest this removes runaway rebound bursts.
    """
    mcmc_df = load_chains("../../mcmc_run/run_df.csv", base_path="../../mcmc_run")
    mcmc_df = mcmc_df[mcmc_df.cost < max_cost]

    if apply_runaway_mask:
        features = mcmc_df["features"]
        mask_runaway = (features["Step_ReboundBurst_burst.soma.v.burst_runaway"] < 0.05) & (
            features["Step_ReboundBurst_burst.soma.v.time_to_last_spike"] > 20000
        )
        mcmc_df = mcmc_df[mask_runaway]

    fig, ax = plt.subplots()
    ax.scatter(
        mcmc_df["parameters"]["gcabar_it2.basal"],
        mcmc_df["features"][BURST_NUMBER_FEATURE],
        rasterized=True,
        s=1,
        c="k",
        alpha=0.1,
    )
    ax.set_xlabel("gcabar_it2.basal")
    ax.set_ylabel(BURST_NUMBER_FEATURE)
    fig.savefig("it2_scatter.pdf")
    plt.close(fig)  # the figure is not reused; release it

    corner_df = mcmc_df.drop(columns=columns_to_drop(mcmc_df.columns))
    print(corner_df)
    plot_corner(
        corner_df,
        feature=("features", BURST_NUMBER_FEATURE),
        filename="burst_number_corner.pdf",
        sort_params=False,
        # cmap="Greys",
    )


if __name__ == "__main__":
    main()
